# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from typing import Any, Union

import pydantic

from haystack import logging
from haystack.core.errors import DeserializationError, SerializationError
from haystack.core.serialization import generate_qualified_class_name, import_class_by_name
from haystack.utils import deserialize_callable, serialize_callable

# Module-level logger; `_primitive_schema_type` uses it to warn about unsupported primitive types.
logger = logging.getLogger(__name__)

# Maps Python primitive types to the type names stored in the serialization schema.
# The insertion order matters: `_primitive_schema_type` walks this mapping with `isinstance`
# and returns the first match, and `bool` is a subclass of `int`, so `bool` has to stay
# ahead of `int` for booleans to be reported as "boolean" rather than "integer".
_PRIMITIVE_TO_SCHEMA_MAP = {type(None): "null", bool: "boolean", int: "integer", float: "number", str: "string"}


def serialize_class_instance(obj: Any) -> dict[str, Any]:
    """
    Wraps the `to_dict` output of an object together with its qualified class name.

    :param obj:
        The object to serialize. It must expose a `to_dict` attribute.
    :returns:
        A dictionary with two keys: `type`, holding the qualified class name of `obj`, and
        `data`, holding whatever `obj.to_dict()` returned.
    :raises SerializationError:
        If the object does not have a `to_dict` method.
    """
    if hasattr(obj, "to_dict"):
        # Call `to_dict` before resolving the class name so failures surface in the same order as before.
        serialized_state = obj.to_dict()
        qualified_name = generate_qualified_class_name(type(obj))
        return {"type": qualified_name, "data": serialized_state}

    raise SerializationError(f"Object of class '{type(obj).__name__}' does not have a 'to_dict' method")


def deserialize_class_instance(data: dict[str, Any]) -> Any:
    """
    Rebuilds an object from the `{"type": ..., "data": ...}` dictionary produced by `serialize_class_instance`.

    :param data:
        The dictionary to deserialize from. It must contain a `type` key (qualified class name)
        and a `data` key (the payload handed to the class' `from_dict`).
    :returns:
        The object returned by `from_dict` of the referenced class.
    :raises DeserializationError:
        If the `type` or `data` key is missing, the class cannot be imported, or the
        class does not have a `from_dict` method.
    """
    if "type" not in data:
        raise DeserializationError("Missing 'type' in serialization data")
    if "data" not in data:
        raise DeserializationError("Missing 'data' in serialization data")

    qualified_name = data["type"]
    try:
        target_class = import_class_by_name(qualified_name)
    except ImportError as import_error:
        # Surface import problems as deserialization failures, keeping the original cause attached.
        raise DeserializationError(f"Class '{qualified_name}' not correctly imported") from import_error

    if hasattr(target_class, "from_dict"):
        return target_class.from_dict(data["data"])

    raise DeserializationError(f"Class '{qualified_name}' does not have a 'from_dict' method")


def _serialize_value_with_schema(payload: Any) -> dict[str, Any]:  # pylint: disable=too-many-return-statements # noqa: PLR0911
    """
    Serializes a value into a schema-aware format suitable for storage or transmission.

    The output format separates the schema information from the actual data, making it easier
    to deserialize complex nested structures correctly.

    The checks run in the order listed here and the first match wins:
    - Pydantic models (dumped with `model_dump()`)
    - Dictionaries (every value is serialized recursively)
    - Lists, tuples, and sets. Lists with mixed types are not supported: the schema of the
      first element is recorded for the whole collection.
    - Objects with a callable `to_dict` attribute (e.g. Haystack dataclasses and Components)
    - Callables that are not classes (serialized with `serialize_callable`)
    - Enum members (stored by member name)
    - Objects with a `__dict__` attribute. Attribute values are serialized recursively, but only
      their data is kept; the per-attribute schema is not stored.
    - Anything else is treated as a primitive (str, int, float, bool, None)

    :param payload: The value to serialize (can be any type)
    :returns: The serialized dict representation of the given value. Contains two keys:
        - "serialization_schema": Contains type information for each field.
        - "serialized_data": Contains the actual data in a simplified format.

    """
    # Handle pydantic
    if isinstance(payload, pydantic.BaseModel):
        type_name = generate_qualified_class_name(type(payload))
        return {"serialization_schema": {"type": type_name}, "serialized_data": payload.model_dump()}

    # Handle dictionary case - iterate through fields
    elif isinstance(payload, dict):
        schema: dict[str, Any] = {}
        data: dict[str, Any] = {}

        for field, val in payload.items():
            # Recursively serialize each field
            serialized_value = _serialize_value_with_schema(val)
            schema[field] = serialized_value["serialization_schema"]
            data[field] = serialized_value["serialized_data"]

        return {"serialization_schema": {"type": "object", "properties": schema}, "serialized_data": data}

    # Handle array case - iterate through elements
    elif isinstance(payload, (list, tuple, set)):
        serialized_list = []
        # Item type is taken from the first element (if any); it stays empty for an empty collection.
        # NOTE: We do not support mixed-type lists
        item_schema: dict[str, Any] = {}
        for index, item in enumerate(payload):
            serialized_value = _serialize_value_with_schema(item)
            if index == 0:
                # Reuse the schema computed in this pass instead of serializing the first element a
                # second time; the extra call would double the work at every nesting level.
                item_schema = serialized_value["serialization_schema"]
            serialized_list.append(serialized_value["serialized_data"])

        base_schema: dict[str, Any] = {"type": "array", "items": item_schema}

        # Add JSON Schema properties to infer sets and tuples
        if isinstance(payload, set):
            base_schema["uniqueItems"] = True
        elif isinstance(payload, tuple):
            base_schema["minItems"] = len(payload)
            base_schema["maxItems"] = len(payload)

        return {"serialization_schema": base_schema, "serialized_data": serialized_list}

    # Handle Haystack style objects (e.g. dataclasses and Components)
    elif hasattr(payload, "to_dict") and callable(payload.to_dict):
        type_name = generate_qualified_class_name(type(payload))
        schema = {"type": type_name}
        return {"serialization_schema": schema, "serialized_data": payload.to_dict()}

    # Handle callable functions serialization (classes themselves are excluded)
    elif callable(payload) and not isinstance(payload, type):
        serialized = serialize_callable(payload)
        return {"serialization_schema": {"type": "typing.Callable"}, "serialized_data": serialized}

    # Handle Enums - stored by member name, not by value
    elif isinstance(payload, Enum):
        type_name = generate_qualified_class_name(type(payload))
        return {"serialization_schema": {"type": type_name}, "serialized_data": payload.name}

    # Handle arbitrary objects with __dict__
    elif hasattr(payload, "__dict__"):
        type_name = generate_qualified_class_name(type(payload))
        schema = {"type": type_name}
        serialized_data = {}
        for key, value in vars(payload).items():
            # Only the data part of each attribute is kept; its schema is dropped.
            serialized_value = _serialize_value_with_schema(value)
            serialized_data[key] = serialized_value["serialized_data"]
        return {"serialization_schema": schema, "serialized_data": serialized_data}

    # Handle primitives
    else:
        schema = {"type": _primitive_schema_type(payload)}
        return {"serialization_schema": schema, "serialized_data": payload}


def _primitive_schema_type(value: Any) -> str:
    """
    Looks up the schema type name for a primitive value.

    The entries of `_PRIMITIVE_TO_SCHEMA_MAP` are checked in order with `isinstance` and the
    first match is returned. When nothing matches, a warning is logged and "string" is returned.
    """
    matched_name = next(
        (schema_name for primitive, schema_name in _PRIMITIVE_TO_SCHEMA_MAP.items() if isinstance(value, primitive)),
        None,
    )
    if matched_name is not None:
        return matched_name

    logger.warning(
        "Unsupported primitive type '{value_type}', falling back to 'string'", value_type=type(value).__name__
    )
    return "string"  # fallback


def _deserialize_value_with_schema(serialized: dict[str, Any]) -> Any:
    """
    Deserializes a value with schema information back to its original form.

    Takes a dict of the form:
      {
         "serialization_schema": {"type": "integer"} or {"type": "object", "properties": {...}},
         "serialized_data": <the actual data>
      }

    NOTE: For array types we only support homogeneous lists (all elements of the same type).

    :param serialized: The serialized dict with schema and data.
    :returns: The deserialized value in its original form.
    """

    if not serialized or "serialization_schema" not in serialized or "serialized_data" not in serialized:
        raise DeserializationError(
            f"Invalid format of passed serialized payload. Expected a dictionary with keys "
            f"'serialization_schema' and 'serialized_data'. Got: {serialized}"
        )
    schema = serialized["serialization_schema"]
    data = serialized["serialized_data"]

    schema_type = schema.get("type")

    if not schema_type:
        # for backward compatibility till Haystack 2.16 we use legacy implementation
        raise DeserializationError(
            "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized "
            "State object created with a version of Haystack older than 2.15.0. "
            "Support for the old serialization format is removed in Haystack 2.16.0. "
            "Please upgrade to the new serialization format to ensure forward compatibility."
        )

    # Handle object case (dictionary with properties)
    if schema_type == "object":
        properties = schema["properties"]
        result: dict[str, Any] = {}
        for field, raw_value in data.items():
            field_schema = properties[field]
            # Recursively deserialize each field - avoid creating temporary dict
            result[field] = _deserialize_value_with_schema(
                {"serialization_schema": field_schema, "serialized_data": raw_value}
            )
        return result

    # Handle array case
    if schema_type == "array":
        # Deserialize each item
        deserialized_items = [
            _deserialize_value_with_schema({"serialization_schema": schema["items"], "serialized_data": item})
            for item in data
        ]
        final_array: Union[list, set, tuple]
        # Is a set if uniqueItems is True
        if schema.get("uniqueItems") is True:
            final_array = set(deserialized_items)
        # Is a tuple if minItems and maxItems are set
        elif schema.get("minItems") is not None and schema.get("maxItems") is not None:
            final_array = tuple(deserialized_items)
        else:
            # Otherwise, it's a list
            final_array = list(deserialized_items)
        return final_array

    # Handle primitive types
    if schema_type in _PRIMITIVE_TO_SCHEMA_MAP.values():
        return data

    # Handle callable functions
    if schema_type == "typing.Callable":
        return deserialize_callable(data)

    # Handle custom class types
    return _deserialize_value({"type": schema_type, "data": data})


def _deserialize_value(value: dict[str, Any]) -> Any:
    """
    Helper function to deserialize values from their envelope format {"type": T, "data": D}.

    The class named by "type" is imported and the data is restored with the first strategy that applies:
    - Custom classes with a callable `from_dict` (e.g. Haystack dataclasses and Components)
    - Pydantic models (via `model_validate`)
    - Enums (looked up by member name)
    - Fallback for arbitrary classes: a blank instance is created with `cls.__new__` (so
      `__init__` is not run) and every entry of the data is set as an attribute on it.

    In the fallback, attribute values that are themselves envelopes (dicts holding both a "type"
    and a "data" key) are deserialized recursively. All other attribute values are assigned
    unchanged, because the `__dict__` branch of `_serialize_value_with_schema` stores attribute
    data without type information, so there is nothing more to restore them from.

    :param value: The envelope to deserialize. Must contain the keys "type" and "data".
    :returns:
        The deserialized value
    :raises DeserializationError:
        If the data cannot be validated into the pydantic model or is not a valid member name of the Enum.
        Errors raised by `import_class_by_name` are not caught here and propagate unchanged.
    """
    # 1) Envelope case
    value_type = value["type"]
    payload = value["data"]

    # Custom class where value_type is a qualified class name
    cls = import_class_by_name(value_type)

    # try from_dict (e.g. Haystack dataclasses and Components)
    if hasattr(cls, "from_dict") and callable(cls.from_dict):
        return cls.from_dict(payload)

    # handle pydantic models
    if issubclass(cls, pydantic.BaseModel):
        try:
            return cls.model_validate(payload)
        except Exception as e:
            raise DeserializationError(
                f"Failed to deserialize data '{payload}' into Pydantic model '{value_type}'"
            ) from e

    # handle enum types
    if issubclass(cls, Enum):
        try:
            return cls[payload]
        except Exception as e:
            raise DeserializationError(f"Value '{payload}' is not a valid member of Enum '{value_type}'") from e

    # fallback: set attributes on a blank instance
    # Only envelope-shaped attribute values can be deserialized further; recursing into anything
    # else (e.g. an int or a str) would fail when indexing it with "type".
    deserialized_payload = {
        attr_name: (
            _deserialize_value(attr_value)
            if isinstance(attr_value, dict) and "type" in attr_value and "data" in attr_value
            else attr_value
        )
        for attr_name, attr_value in payload.items()
    }
    instance = cls.__new__(cls)
    for attr_name, attr_value in deserialized_payload.items():
        setattr(instance, attr_name, attr_value)
    return instance
